# 相关scatter的操作
# 按照索引 将目标值移到 目标张量上

import torch
pred = torch.tensor([[1, 1, 2, 3],
                    [1, 1, 2, 3]])
target = torch.tensor([0, 1])
one_hot = torch.zeros_like(pred).scatter(dim=1, index=target.view(-1, 1), value=1)
print(one_hot)
print("-"*20)

target = torch.tensor([0, 1, -1, 2])  # 真实标签
ignore_index = -1                     # 设定要忽略的值

valid_mask = target != ignore_index
print(valid_mask)  # 输出：tensor([ True,  True, False,  True])

print(target[valid_mask])
